# coding = utf-8

import os
import tempfile
import unittest

import numpy as np
import pandas as pd

from PyMatterSim.reader.dump_reader import DumpReader
from PyMatterSim.reader.lammps_reader_helper import read_additions
from PyMatterSim.static.gr import conditional_gr, gr
from PyMatterSim.utils.logging import get_logger_handle

# Logger shared by every test in this module.
logger = get_logger_handle(__name__)

# Directory holding the sample dump files. The path is relative, so it is
# resolved against the current working directory (presumably the repository
# root when the suite is run).
READ_TEST_FILE_PATH: str = "tests/sample_test_data"


class Testgr(unittest.TestCase):
    """
    Regression tests for ``gr`` and ``conditional_gr``.

    Each test reads a sample dump file, computes the radial distribution
    function, and compares every 50th row of the result against stored
    reference values.
    """

    def setUp(self) -> None:
        super().setUp()
        self.test_file_unary = f"{READ_TEST_FILE_PATH}/unary.dump"
        self.test_file_binary = f"{READ_TEST_FILE_PATH}/dump_2D.atom"
        self.test_file_ternary = f"{READ_TEST_FILE_PATH}/ternary.dump"
        self.test_file_vector = f"{READ_TEST_FILE_PATH}/binary_velocity.dump"
        # CSV outputs go into a per-test temporary directory that addCleanup
        # removes whether or not the test passes. A trailing os.remove() would
        # be skipped on an assertion failure and leave files in the CWD.
        tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(tmpdir.cleanup)
        self.output_dir = tmpdir.name

    def _output_path(self, filename: str) -> str:
        """Return the path of ``filename`` inside this test's temporary directory."""
        return os.path.join(self.output_dir, filename)

    @staticmethod
    def _read_dump(filename: str, ndim: int):
        """Create a DumpReader for ``filename``, call read_onefile(), and return it."""
        readdump = DumpReader(filename, ndim=ndim)
        readdump.read_onefile()
        return readdump

    def test_gr_unary(self) -> None:
        """
        Test gr works properly for unary system
        """
        logger.info(f"Starting test gr using {self.test_file_unary}...")
        readdump = self._read_dump(self.test_file_unary, ndim=3)
        outputfile = self._output_path("gr_unary.csv")
        gr(
            readdump.snapshots,
            ppp=np.array([1, 1, 1]),
            rdelta=0.01,
            outputfile=outputfile,
        ).getresults()

        # Compare every 50th row of the written CSV against reference values.
        result = pd.read_csv(outputfile)
        np.testing.assert_almost_equal(
            [
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                1.025326,
                1.710015,
                0.785585,
                0.521454,
                0.729149,
                1.130114,
                1.146455,
                1.213279,
                0.65381,
                0.81293,
                1.06233,
                0.983969,
                0.91615,
                1.007661,
                1.089595,
                1.092679,
                0.944498,
            ],
            result["gr"].values[::50],
        )
        logger.info(f"Finishing test gr using {self.test_file_unary}...")

    def test_gr_binary(self) -> None:
        """
        Test gr works properly for binary system
        """
        logger.info(f"Starting test gr using {self.test_file_binary}...")
        readdump = self._read_dump(self.test_file_binary, ndim=2)
        outputfile = self._output_path("gr_binary.csv")
        gr(
            readdump.snapshots,
            ppp=np.array([1, 1]),
            rdelta=0.01,
            outputfile=outputfile,
        ).getresults()

        # Compare every 50th row of the total and partial g(r) columns.
        result = pd.read_csv(outputfile)
        np.testing.assert_almost_equal(
            [
                0.0,
                0.0,
                0.72847,
                0.384933,
                0.806808,
                1.368034,
                1.011177,
                1.113767,
                0.876326,
                1.026788,
                0.946471,
                1.035938,
                0.986363,
                1.007043,
                1.050854,
                0.977112,
                1.010957,
                0.967467,
                0.998938,
                1.001044,
                1.017255,
                0.998411,
                0.982957,
                0.994798,
                1.003319,
                1.008359,
                0.9902,
                1.008315,
                0.995181,
                1.00977,
                1.010021,
                0.992355,
                0.99926,
                1.005825,
                1.003841,
                1.005463,
                0.997589,
                1.003351,
                0.995746,
                1.003872,
                0.998639,
                1.008067,
                1.000801,
                1.003405,
                1.013182,
                1.004504,
                0.999414,
                1.006322,
                1.003714,
                1.004043,
                0.998428,
                0.996152,
                0.996412,
                0.996975,
                1.000487,
                1.008384,
                0.996359,
                0.999228,
                0.993921,
                1.000036,
                1.008195,
                1.001124,
                0.998326,
                0.996889,
                0.998939,
                1.0031,
                0.999168,
                1.001377,
                0.993674,
                0.987245,
                1.004333,
                0.996842,
                1.003297,
                1.000263,
                1.009594,
                0.995659,
                1.002695,
                1.004935,
                1.00005,
                1.002332,
                1.005065,
                1.000328,
                0.999666,
                1.002931,
                1.002739,
                0.999046,
                1.003004,
                0.99989,
                0.999148,
                1.008478,
                1.00487,
                0.99853,
                1.004657,
                0.997811,
                0.995662,
                1.000433,
                1.005383,
                0.999889,
                1.001236,
                1.001308,
            ],
            result["gr"].values[::50],
        )
        np.testing.assert_almost_equal(
            [
                0.0,
                0.0,
                0.0,
                0.246259,
                0.279056,
                1.912521,
                0.956198,
                1.5563,
                0.642415,
                0.974198,
                0.80883,
                0.951376,
                1.099018,
                0.962408,
                1.194818,
                0.910821,
                1.035728,
                0.901313,
                1.00668,
                1.022224,
                1.027347,
                1.038793,
                0.954929,
                1.012682,
                0.977656,
                1.028737,
                0.989186,
                1.020376,
                0.991099,
                1.016226,
                1.011258,
                0.985662,
                1.005742,
                0.99062,
                1.006429,
                0.995175,
                1.002111,
                1.011688,
                0.993128,
                1.005526,
                0.989335,
                1.020691,
                1.00021,
                1.012953,
                1.010226,
                1.002088,
                0.993557,
                1.011567,
                1.001227,
                1.007893,
                0.997239,
                0.989359,
                0.989321,
                0.995924,
                0.999899,
                1.010076,
                0.997254,
                0.996328,
                0.997259,
                0.99389,
                1.013335,
                0.997992,
                1.00239,
                0.993294,
                0.999729,
                1.006275,
                0.993762,
                0.99435,
                0.988314,
                0.982369,
                1.002103,
                0.993634,
                1.004039,
                0.994128,
                1.01109,
                1.002113,
                1.004931,
                1.003625,
                0.998207,
                1.001837,
                1.010062,
                1.002151,
                0.999144,
                0.999071,
                1.003395,
                1.000782,
                1.005418,
                0.995606,
                0.989886,
                1.006352,
                1.003153,
                0.999898,
                1.007786,
                0.996982,
                1.000807,
                0.996207,
                1.00786,
                0.994609,
                1.011121,
                1.003596,
            ],
            result["gr11"].values[::50],
        )
        np.testing.assert_almost_equal(
            [
                0.0,
                0.0,
                1.639879,
                0.371371,
                1.439545,
                0.84156,
                1.060646,
                0.688139,
                1.054527,
                1.122313,
                1.035964,
                1.136337,
                0.86989,
                1.05462,
                0.919401,
                1.035545,
                0.993958,
                1.015732,
                1.007873,
                0.980788,
                1.015712,
                0.958909,
                1.011447,
                0.978577,
                1.031769,
                0.996932,
                0.992214,
                0.997713,
                0.996987,
                1.003919,
                1.010867,
                1.000174,
                0.991224,
                1.013815,
                1.011901,
                1.011734,
                0.995179,
                0.994621,
                0.995751,
                0.999541,
                1.003716,
                0.999439,
                0.997822,
                0.995148,
                1.010114,
                1.007731,
                1.004829,
                0.999794,
                1.002313,
                1.004436,
                1.006187,
                1.001408,
                1.001277,
                0.998826,
                0.997686,
                1.004194,
                0.996756,
                1.005617,
                0.987738,
                1.002183,
                1.006689,
                1.007054,
                1.000058,
                0.995468,
                0.99738,
                1.000996,
                1.008324,
                1.00675,
                1.001598,
                0.98779,
                1.005164,
                1.006629,
                1.003397,
                1.001745,
                1.009084,
                0.985849,
                1.006251,
                1.007812,
                1.005254,
                1.008275,
                1.002981,
                1.004293,
                1.002287,
                1.004508,
                1.004901,
                0.994901,
                1.001695,
                1.002898,
                1.00593,
                1.012758,
                1.013193,
                0.994502,
                0.999113,
                1.000636,
                0.994687,
                1.003254,
                1.004328,
                1.005054,
                0.994273,
                0.999574,
            ],
            result["gr12"].values[::50],
        )
        logger.info(f"Finishing test gr using {self.test_file_binary}...")

    def test_gr_ternary(self) -> None:
        """
        Test gr works properly for ternary system
        """
        logger.info(f"Starting test gr using {self.test_file_ternary}...")
        readdump = self._read_dump(self.test_file_ternary, ndim=3)
        outputfile = self._output_path("gr_ternary.csv")
        # ppp passed as an ndarray, matching the unary and binary tests.
        gr(
            readdump.snapshots,
            ppp=np.array([1, 1, 1]),
            rdelta=0.01,
            outputfile=outputfile,
        ).getresults()

        # Columns are checked by position; the trailing comment on each
        # assertion names the partial g(r) it corresponds to.
        result = pd.read_csv(outputfile)
        np.testing.assert_almost_equal(
            [
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                1.479613,
                1.772747,
                0.721019,
                0.594365,
                1.076223,
                1.154532,
                1.067021,
                0.89377,
                0.913816,
                1.028702,
                1.026303,
                1.041439,
                0.995372,
                1.022583,
                1.049341,
                1.013536,
                0.983176,
                0.999846,
                0.978523,
                0.998431,
                1.019369,
                1.013795,
                1.000775,
                0.999512,
                1.020727,
                0.995072,
                0.997777,
                1.008622,
                1.015689,
                0.986608,
                0.990026,
                1.013056,
                1.007252,
                1.000747,
                0.994512,
                1.010576,
                1.006734,
                0.997109,
                1.002819,
                0.988767,
                1.011061,
                0.997904,
                1.00065,
                1.014095,
                1.005123,
                0.999067,
                1.009004,
                1.001507,
            ],
            result.iloc[:, 1][::50].values,
        )  # gtotal
        np.testing.assert_almost_equal(
            [
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.870949,
                2.044691,
                0.781527,
                0.543318,
                0.987394,
                1.163588,
                1.124287,
                0.89024,
                0.898273,
                1.011417,
                1.023626,
                1.031904,
                1.0142,
                1.001242,
                1.052359,
                1.01178,
                0.975759,
                1.00294,
                0.984987,
                1.014312,
                1.015111,
                1.009879,
                0.998306,
                0.994818,
                1.02898,
                0.997848,
                0.995766,
                1.012367,
                1.023713,
                0.979891,
                0.986575,
                1.013189,
                1.014389,
                0.998386,
                1.000141,
                1.008472,
                1.007082,
                1.001221,
                1.005988,
                0.982958,
                1.009273,
                0.986439,
                1.00632,
                1.007985,
                1.007037,
                0.996224,
                0.999403,
                1.000366,
            ],
            result.iloc[:, 2][::50].values,
        )  # g11
        np.testing.assert_almost_equal(
            [
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                2.267606,
                0.0,
                0.0,
                0.0,
                2.80449,
                3.408212,
                0.469536,
                0.789201,
                2.017624,
                1.159918,
                1.010514,
                0.44411,
                1.967139,
                0.877377,
                1.259997,
                0.995056,
                1.805176,
                1.174907,
                0.967504,
                0.789859,
                1.000942,
                1.009588,
                0.702162,
                1.30584,
                1.082102,
                0.82159,
                1.538911,
                1.166519,
                0.992445,
                1.033355,
                0.742984,
                0.658398,
                1.537472,
                1.063682,
                1.346463,
                1.279997,
                0.50764,
                1.225524,
                1.046125,
                0.852195,
                1.29236,
                0.725944,
                0.927184,
                0.864268,
                1.208494,
                0.887554,
                1.006215,
                0.904771,
            ],
            result.iloc[:, 3][::50].values,
        )  # g22
        np.testing.assert_almost_equal(
            [
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                1.700705,
                0.393944,
                1.158264,
                0.221778,
                1.402245,
                1.27808,
                1.17384,
                0.690551,
                1.597286,
                1.232413,
                1.326299,
                1.332331,
                1.229462,
                0.789639,
                0.944998,
                1.030594,
                0.999294,
                0.969298,
                0.99438,
                0.863908,
                1.137434,
                0.967522,
                1.111756,
                1.051927,
                0.879208,
                1.074387,
                0.947022,
                1.041534,
                1.096912,
                1.107167,
                1.068039,
                1.031491,
                0.872619,
                1.014437,
                0.935043,
                1.164442,
                0.956054,
                0.935269,
                1.069202,
                1.006472,
                1.130815,
                0.967925,
                1.030204,
                1.080335,
                1.001154,
                1.086684,
                1.006215,
                1.062579,
            ],
            result.iloc[:, 4][::50].values,
        )  # g33
        np.testing.assert_almost_equal(
            [
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                4.068353,
                0.50981,
                0.136266,
                0.808837,
                1.608458,
                1.269726,
                0.759544,
                0.742777,
                1.058264,
                1.07463,
                1.003084,
                0.855566,
                0.890998,
                1.027047,
                1.00985,
                0.928161,
                0.917758,
                1.019405,
                0.945372,
                0.958284,
                0.958121,
                1.004639,
                1.016528,
                1.051927,
                0.990601,
                0.962859,
                0.948763,
                1.001507,
                0.958646,
                1.005857,
                0.995653,
                1.023745,
                1.002168,
                1.023127,
                0.966945,
                1.01856,
                0.980441,
                0.981747,
                0.96468,
                1.019868,
                1.013065,
                1.035933,
                0.992329,
                1.043658,
                1.02729,
                1.002012,
                1.050607,
                1.006263,
            ],
            result.iloc[:, 5][::50].values,
        )  # g12
        np.testing.assert_almost_equal(
            [
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                2.567731,
                1.436737,
                0.766498,
                0.665334,
                1.092926,
                1.035829,
                1.015026,
                1.009713,
                0.820896,
                1.03625,
                1.03652,
                1.149462,
                0.925712,
                1.112204,
                1.077019,
                1.066131,
                1.037218,
                0.988304,
                0.956438,
                0.959736,
                1.075879,
                1.034333,
                0.983256,
                0.991115,
                1.009498,
                1.003752,
                1.017527,
                0.997422,
                1.017025,
                0.999344,
                1.005214,
                1.023745,
                0.98078,
                1.012699,
                0.982896,
                1.001828,
                1.025233,
                0.991707,
                1.00857,
                1.00215,
                1.007281,
                1.034747,
                0.981724,
                1.024048,
                0.982685,
                1.000673,
                1.03002,
                1.003169,
            ],
            result.iloc[:, 6][::50].values,
        )  # g13
        np.testing.assert_almost_equal(
            [
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                4.535213,
                0.393944,
                0.289566,
                1.774224,
                1.928087,
                0.852053,
                0.469536,
                0.690551,
                1.42915,
                1.522393,
                0.947357,
                1.332331,
                1.229462,
                1.272197,
                0.748123,
                0.959518,
                0.934823,
                0.793062,
                1.15563,
                1.036689,
                0.773455,
                1.072687,
                1.306801,
                0.906833,
                0.997563,
                0.884789,
                1.139386,
                0.91655,
                0.822684,
                0.959544,
                0.917121,
                0.801051,
                0.986891,
                0.748517,
                0.991146,
                1.02222,
                1.066043,
                0.927206,
                0.923052,
                0.933007,
                0.997365,
                0.974647,
                0.972255,
                1.006255,
                1.013002,
                1.092374,
                0.995278,
                0.978415,
            ],
            result.iloc[:, 7][::50].values,
        )  # g23
        logger.info(f"Finishing test gr using {self.test_file_ternary}...")

    def test_gr_condition(self) -> None:
        """
        Test gr condition works properly
        """
        logger.info("Starting test conditional_gr...")
        readdump = self._read_dump(self.test_file_ternary, ndim=3)
        snapshot = readdump.snapshots.snapshots[0]

        # Boolean condition selecting type-2 particles: the "gr" column must
        # equal gr's total "gr" and "gA" must equal gr's "gr22".
        # Each (deterministic) computation is done once and reused.
        selected = conditional_gr(snapshot, condition=snapshot.particle_type == 2)
        reference = gr(readdump.snapshots).getresults()
        np.testing.assert_almost_equal(
            selected[["r", "gr"]].values,
            reference[["r", "gr"]].values,
        )
        np.testing.assert_almost_equal(
            selected[["r", "gA"]].values,
            reference[["r", "gr22"]].values,
        )

        # Per-particle condition values read from columns 5, 6 and 7 of the
        # vector dump file, stacked column-wise into an (N, 3) array.
        vx = read_additions(self.test_file_vector, ncol=5)
        vy = read_additions(self.test_file_vector, ncol=6)
        vz = read_additions(self.test_file_vector, ncol=7)
        vector = np.vstack((vx, vy, vz)).T

        readdump = self._read_dump(self.test_file_vector, ndim=3)
        snapshot = readdump.snapshots.snapshots[0]

        # The gA / gr ratio is NaN in bins where both are zero (0 / 0);
        # assert_almost_equal treats NaNs at matching positions as equal.
        grresults_float = conditional_gr(snapshot, condition=vx.ravel())
        np.testing.assert_almost_equal(
            np.array(
                [
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    -0.68633941,
                    0.44707095,
                    1.14853632,
                    -0.18622456,
                    -0.26693063,
                    2.61779559,
                    -1.33441647,
                    0.27959398,
                    -0.06522355,
                    0.59165186,
                    0.40821972,
                    -0.09813158,
                    -0.1407156,
                    -0.73305988,
                    -0.0547678,
                    -0.25349853,
                    0.76305551,
                    0.20946121,
                    0.83117225,
                    -0.11123542,
                    -0.05548569,
                    0.2326811,
                    0.24431357,
                    -0.00415691,
                    -0.02648211,
                    0.16408667,
                    0.28909996,
                    -0.28155182,
                    -0.18776073,
                    -0.04934026,
                    0.32763147,
                    0.26464412,
                    -0.19693263,
                    0.07655692,
                    -0.13725776,
                    0.11123452,
                    -0.40416159,
                    0.48699466,
                    0.19801818,
                    -0.09087962,
                    -0.12902979,
                    0.10347054,
                    -0.01646532,
                    -0.57267094,
                    0.01568593,
                    -0.20689161,
                    -0.36446847,
                    -0.04910013,
                    0.06117577,
                ]
            ),
            (grresults_float["gA"] / grresults_float["gr"]).values[::50],
        )

        grresults_vector = conditional_gr(snapshot, condition=vector, conditiontype="vector")
        np.testing.assert_almost_equal(
            np.array(
                [
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    -0.9315299,
                    2.1246504,
                    -0.42817632,
                    3.9757448,
                    -2.02032924,
                    1.56137977,
                    -3.15865012,
                    -2.06905277,
                    -1.68972952,
                    1.49695568,
                    0.95324952,
                    0.23841881,
                    -0.17596337,
                    0.68157285,
                    1.00162379,
                    -0.15607883,
                    0.35245528,
                    2.78187328,
                    0.05600847,
                    -0.47184815,
                    0.13917716,
                    0.88740995,
                    0.55349047,
                    -0.21617755,
                    0.28791209,
                    0.23008911,
                    -0.05526148,
                    0.43465378,
                    -1.34250606,
                    -0.36927568,
                    0.41785121,
                    0.18696626,
                    -0.38810259,
                    0.24924333,
                    0.93155147,
                    -0.56219654,
                    -0.8035333,
                    0.0806967,
                    -0.43397226,
                    -0.92165167,
                    0.27119987,
                    0.1445809,
                    -0.71100812,
                    -0.32399405,
                    0.15859521,
                    -0.14847221,
                    -0.34353328,
                    -0.34054109,
                    -0.63861419,
                ]
            ),
            (grresults_vector["gA"] / grresults_vector["gr"]).values[::50],
        )
        logger.info("Finishing test conditional_gr...")
